{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Built-in Neural Networks in Φ<sub>Flow</sub>\n",
    "\n",
    "Apart from being a general purpose differential physics library, Φ<sub>Flow</sub> also provides a number of backend-agnostic way of setting up optimizers and some neural networks.\n",
    "\n",
    "The following network architectures are supported:\n",
    "\n",
    "* Fully-connected networks: `dense_net()`\n",
    "* Convolutional networks: `conv_net()`\n",
    "* Residual networks: `res_net()`\n",
    "* U-Nets: `u_net()`\n",
    "* Convolutional classifiers: `conv_classifier()`\n",
    "\n",
    "In addition to zero-padding, the convolutional neural networks all support circular periodic padding across feature map spatial dimensions to maintain periodicity.\n",
    "\n",
    "All network-related convenience functions in Φ<sub>Flow</sub> support PyTorch, TensorFlow and Jax/Stax, setting up native networks.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from phi.tf.flow import *\n",
    "from phi.jax.stax.flow import *\n",
    "from phi.torch.flow import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Fully-connected Networks\n",
    "Fully-connected neural networks are available in Φ<sub>Flow</sub> via `dense_net()`.\n",
    "\n",
    "#### Arguments\n",
    "\n",
    "* `in_channels` : size of input layer, int\n",
    "* `out_channels` = size of output layer, int\n",
    "* `layers` : tuple of linear layers between input and output neurons, list or tuple\n",
    "* `activation` : activation function used within the layers, string\n",
    "* `batch_norm` : use of batch norm after each linear layer, bool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "net = dense_net(in_channels=1, out_channels=1, layers=[8, 8], activation='ReLU', batch_norm=False)  # Implemented for PyTorch, TensorFlow, Jax-Stax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial loss: \u001b[92m(batchᵇ=100)\u001b[0m \u001b[94m0.246 ± 0.178\u001b[0m \u001b[37m(2e-04...6e-01)\u001b[0m\n",
      "Final loss: \u001b[92m(batchᵇ=100)\u001b[0m \u001b[94m0.087 ± 0.073\u001b[0m \u001b[37m(4e-04...3e-01)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "optimizer = adam(net, 1e-3)\n",
    "BATCH = batch(batch=100)\n",
    "\n",
    "def loss_function(data: Tensor):\n",
    "    prediction = math.native_call(net, data)\n",
    "    label = math.sin(data)\n",
    "    return math.l2_loss(prediction - label), data, label\n",
    "\n",
    "print(f\"Initial loss: {loss_function(math.random_normal(BATCH))[0]}\")\n",
    "for i in range(100):\n",
    "    loss, _data, _label = update_weights(net, optimizer, loss_function, data=math.random_normal(BATCH))\n",
    "print(f\"Final loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## U-Nets\n",
    "Φ<sub>Flow</sub> provides a built in U-net architecture, classically popular for Semantic Segmentation in Computer Vision, composed of downsampling and upsampling layers.\n",
    "\n",
    " #### Arguments\n",
    "\n",
    "* `in_channels`: input channels of the feature map, dtype : int\n",
    "* `out_channels` : output channels of the feature map, dtype : int\n",
    "* `levels` : number of levels of down-sampling and upsampling, dtype : int\n",
    "* `filters` : filter sizes at each down/up sampling convolutional layer, if the input is integer all conv layers have the same filter size,<br> dtype : int or tuple\n",
    "* `activation` : activation function used within the layers, dtype : string\n",
    "* `batch_norm` : use of batchnorm after each conv layer, dtype : bool\n",
    "* `in_spatial` : spatial dimensions of the input feature map, dtype : int\n",
    "* `use_res_blocks` : use convolutional blocks with skip connections instead of regular convolutional blocks, dtype : bool\n",
    "* `**kwargs` : placeholder for arguments not supported by the function (such as layers in res_net and conv_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "net = u_net(in_channels= 1, out_channels= 2, levels=4, filters=16, batch_norm=True, activation='ReLU', in_spatial=2, use_res_blocks=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Loss function for training in the network to identify noise parameters\n",
    "def loss_function(scale: Tensor, smoothness: Tensor):\n",
    "    grid = CenteredGrid(Noise(scale=scale, smoothness=smoothness), x=64, y=64)\n",
    "\n",
    "    print(f'Grid Shape : {grid.shape}')\n",
    "    pred_scale, pred_smoothness = field.native_call(net, grid).vector\n",
    "    return math.l2_loss(pred_scale - scale) + math.l2_loss(pred_smoothness - smoothness)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Initial loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.28e+04 ± 5.6e+04\u001b[0m \u001b[37m(8e+03...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 0, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.28e+04 ± 5.6e+04\u001b[0m \u001b[37m(8e+03...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 1, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.09e+04 ± 5.3e+04\u001b[0m \u001b[37m(9e+03...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 2, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.95e+04 ± 5.1e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 3, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.85e+04 ± 5.0e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 4, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.74e+04 ± 4.8e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 5, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.68e+04 ± 4.7e+04\u001b[0m \u001b[37m(8e+03...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 6, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.60e+04 ± 4.7e+04\u001b[0m \u001b[37m(9e+03...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 7, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.55e+04 ± 4.6e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 8, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.49e+04 ± 4.6e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 9, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.41e+04 ± 4.5e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n",
      "Final loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m7.41e+04 ± 4.5e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "optimizer = adam(net, learning_rate=1e-3)\n",
    "gt_scale = math.random_uniform(batch(examples=50), low=1, high=10)\n",
    "gt_smoothness = math.random_uniform(batch(examples=50), low=.5, high=3)\n",
    "\n",
    "print(f\"Initial loss: {loss_function(gt_scale, gt_smoothness)}\")\n",
    "for i in range(10):\n",
    "    loss = update_weights(net, optimizer, loss_function, gt_scale, gt_smoothness)\n",
    "    print(f'Iter : {i}, Loss : {loss}')\n",
    "print(f\"Final loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Convolutional Networks\n",
    "Built in Conv-Nets are also provided. Contrary to the classical convolutional neural networks, the feature map spatial size remains the same throughout the layers.\n",
    "Each layer of the network is essentially a convolutional block comprising of two conv layers.\n",
    "A filter size of 3 is used in the convolutional layers.\n",
    "\n",
    "#### Arguments\n",
    "\n",
    "* `in_channels` : input channels of the feature map, dtype : int\n",
    "* `out_channels` : output channels of the feature map, dtype : int <br>\n",
    "* `layers` : list or tuple of output channels for each intermediate layer between the input and final output channels, dtype : list or tuple <br>\n",
    "* `activation` : activation function used within the layers, dtype : string <br>\n",
    "* `batch_norm` : use of batchnorm after each conv layer, dtype : bool <br>\n",
    "* `in_spatial` : spatial dimensions of the input feature map, dtype : int <br>\n",
    "* `**kwargs` : placeholder for arguments not supported by the function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "net = conv_net(in_channels=1, out_channels=2, layers=[2,4,4,2], activation='ReLU', batch_norm=True, in_spatial=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Initial loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.21e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 0, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.21e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 1, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.20e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 2, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.18e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 3, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.17e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 4, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.16e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 5, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.15e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 6, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.14e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 7, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.12e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 8, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.11e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 9, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.10e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n",
      "Final loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.10e+04 ± 4.9e+04\u001b[0m \u001b[37m(1e+04...2e+05)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "optimizer = adam(net, learning_rate=1e-3)\n",
    "gt_scale = math.random_uniform(batch(examples=50), low=1, high=10)\n",
    "gt_smoothness = math.random_uniform(batch(examples=50), low=.5, high=3)\n",
    "\n",
    "print(f\"Initial loss: {loss_function(gt_scale, gt_smoothness)}\")\n",
    "for i in range(10):\n",
    "    loss = update_weights(net, optimizer, loss_function, gt_scale, gt_smoothness)\n",
    "    print(f'Iter : {i}, Loss : {loss}')\n",
    "print(f\"Final loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Residual Networks\n",
    "Built in Res-Nets are provided in the Φ<sub>Flow</sub> framework. Similar to the conv-net, the feature map spatial size remains the same throughout the layers.<br>These networks use residual blocks composed of two conv layers with a skip connection added from the input to the output feature map.<br> A default filter size of 3 is used in the convolutional layers.<br><br>\n",
    "\n",
    "#### Arguments\n",
    "\n",
    "* `in_channels` : input channels of the feature map, dtype : int\n",
    "* `out_channels` : output channels of the feature map, dtype : int\n",
    "* `layers` : list or tuple of output channels for each intermediate layer between the input and final output channels, dtype : list or tuple\n",
    "* `activation` : activation function used within the layers, dtype : string\n",
    "* `batch_norm` : use of batchnorm after each conv layer, dtype : bool\n",
    "* `in_spatial` : spatial dimensions of the input feature map, dtype : int\n",
    "* `**kwargs` : placeholder for arguments not supported by the function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "net = res_net(in_channels=1, out_channels=2, layers=[2,4,4,2], activation='ReLU', batch_norm=True, in_spatial=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Initial loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.13e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 0, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.13e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 1, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.10e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 2, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.08e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 3, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.05e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 4, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.04e+04 ± 5.6e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 5, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.01e+04 ± 5.5e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 6, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.00e+04 ± 5.6e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 7, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.97e+04 ± 5.5e+04\u001b[0m \u001b[37m(3e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 8, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.97e+04 ± 5.6e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n",
      "Grid Shape : (examplesᵇ=50, xˢ=64, yˢ=64)\n",
      "Iter : 9, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.95e+04 ± 5.6e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n",
      "Final loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.95e+04 ± 5.6e+04\u001b[0m \u001b[37m(2e+04...2e+05)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "optimizer = adam(net, learning_rate=1e-3)\n",
    "gt_scale = math.random_uniform(batch(examples=50), low=1, high=10)\n",
    "gt_smoothness = math.random_uniform(batch(examples=50), low=.5, high=3)\n",
    "\n",
    "print(f\"Initial loss: {loss_function(gt_scale, gt_smoothness)}\")\n",
    "for i in range(10):\n",
    "    loss = update_weights(net, optimizer, loss_function, gt_scale, gt_smoothness)\n",
    "    print(f'Iter : {i}, Loss : {loss}')\n",
    "print(f\"Final loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Invertible Nets\n",
    "\n",
    "Phiflow also provides invertible neural networks that are capable of inverting the output tensor back to the input tensor initially passed.\\\n",
    "These networks have far reaching applications in predicting input parameters of a problem given its observations.\\\n",
    "Invertible nets are composed of multiple concatenated coupling blocks wherein each such block consists of arbitrary neural networks.\n",
    " \n",
    "Currently these arbitrary neural networks could be set as u_net(default), conv_net, res_net or dense_net blocks with in_channels = out_channels. The architecture used is popularized by [Real NVP](https://arxiv.org/abs/1605.08803).\n",
    "\n",
    "### Arguments\n",
    "* `in_channels` : input channels of the feature map, dtype : int\n",
    "* `num_blocks` : number of coupling blocks inside the invertible net, dtype : int\n",
    "* `activation` : activation function used within the layers, dtype : string\n",
    "* `batch_norm` : use of batchnorm after each layer, dtype : bool\n",
    "* `in_spatial` : spatial dimensions of the input feature map, dtype : int\n",
    "* `net` : type of neural network blocks used in coupling layers, dtype : str\n",
    "* `**kwargs` : placeholder for arguments not supported by the function\n",
    "\n",
    "Note: Currently supported values for net are 'u_net'(default), 'conv_net' and 'res_net'. For choosing 'dense_net' as the network block in coupling layers in_spatial must be set to zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def loss_function_inn(grid, scale: Tensor, smoothness: Tensor):\n",
    "    \n",
    "    \n",
    "    pred_scale, pred_smoothness = field.native_call(net, grid).vector\n",
    "    pred_scale = math.expand(pred_scale, channel(c=1))\n",
    "    pred_smoothness = math.expand(pred_smoothness, channel(c=1))\n",
    "    output_grid = math.concat((pred_scale, pred_smoothness), dim='c')\n",
    "    \n",
    "    return math.l2_loss(pred_scale - scale) + math.l2_loss(pred_smoothness - smoothness), output_grid\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = invertible_net(in_channels=2, num_blocks=2, activation='ReLU', batch_norm=True, in_spatial=2, net='u_net')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter : 0, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m2.06e+04 ± 1.3e+04\u001b[0m \u001b[37m(4e+03...5e+04)\u001b[0m\n",
      "Iter : 1, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.77e+04 ± 1.2e+04\u001b[0m \u001b[37m(3e+03...4e+04)\u001b[0m\n",
      "Iter : 2, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.56e+04 ± 1.1e+04\u001b[0m \u001b[37m(3e+03...4e+04)\u001b[0m\n",
      "Iter : 3, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.42e+04 ± 9.9e+03\u001b[0m \u001b[37m(3e+03...4e+04)\u001b[0m\n",
      "Iter : 4, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.30e+04 ± 9.1e+03\u001b[0m \u001b[37m(3e+03...4e+04)\u001b[0m\n",
      "Iter : 5, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.19e+04 ± 8.2e+03\u001b[0m \u001b[37m(3e+03...3e+04)\u001b[0m\n",
      "Iter : 6, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.09e+04 ± 7.4e+03\u001b[0m \u001b[37m(3e+03...3e+04)\u001b[0m\n",
      "Iter : 7, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m1.00e+04 ± 6.6e+03\u001b[0m \u001b[37m(3e+03...3e+04)\u001b[0m\n",
      "Iter : 8, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m9.18e+03 ± 5.9e+03\u001b[0m \u001b[37m(3e+03...2e+04)\u001b[0m\n",
      "Iter : 9, Loss : \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.43e+03 ± 5.3e+03\u001b[0m \u001b[37m(3e+03...2e+04)\u001b[0m\n",
      "Final loss: \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.43e+03 ± 5.3e+03\u001b[0m \u001b[37m(3e+03...2e+04)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "optimizer = adam(net, learning_rate=1e-3)\n",
    "gt_scale = math.random_uniform(batch(examples=50), low=1, high=10)\n",
    "gt_smoothness = math.random_uniform(batch(examples=50), low=.5, high=3)\n",
    "\n",
    "input_grid = CenteredGrid(Noise(scale=gt_scale, smoothness=gt_smoothness), x=32, y=32)\n",
    "    \n",
    "# Expanding channels to 2 by repeating it along channel dimension \n",
    "# in order to obtain feature maps for both pred_scale and pred_smoothness (in_channels = out_channels = 2)\n",
    "input_grid = math.expand(input_grid, channel(c=2))\n",
    "\n",
    "for i in range(10):\n",
    "    loss, grid = update_weights(net, optimizer, loss_function_inn, input_grid, gt_scale, gt_smoothness)\n",
    "    print(f'Iter : {i}, Loss : {loss}')\n",
    "    \n",
    "print(f\"Final loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Original input feature map can be obtained by passing the predicted feature map once again through the n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss between initial input and prediction \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m8.82e+03 ± 6.9e+03\u001b[0m \u001b[37m(7e+02...2e+04)\u001b[0m\n",
      "Loss between initial input and reconstructed input \u001b[92m(examplesᵇ=50)\u001b[0m \u001b[94m4.84e-09 ± 2.3e-09\u001b[0m \u001b[37m(2e-09...1e-08)\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "\n",
    "grid = field.native_call(net, input_grid, False)\n",
    "reconstructed_input =  field.native_call(net, grid, True) # invert = True   \n",
    "print('Loss between initial input and prediction',math.l2_loss(input_grid - grid))\n",
    "print('Loss between initial input and reconstructed input',math.l2_loss(input_grid - reconstructed_input))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('torch-tf-jax': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "3188f476244e99b7f99d90d67605181ae6381060b871ae5a45168d18984807c5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
